{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1geVKc8ESUph"
   },
   "source": [
    "# Face Anti Spoofing Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ItQsbhw6SQbN"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "82pbDU5xScdH"
   },
   "source": [
    "## Display function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "hCCnoNDZSarC"
   },
   "outputs": [],
   "source": [
    "def imshow_np(img):\n",
    "    height,width,depth = img.shape\n",
    "    if depth==1:\n",
    "        img=img[:,:,0]\n",
    "    plt.imshow(img)\n",
    "    plt.show()\n",
    "\n",
    "def imshow(img):\n",
    "    imshow_np(img.numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "uv4WF0IpSf9W"
   },
   "source": [
    "## Data sets creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CDDNczBWSeOU"
   },
   "outputs": [],
   "source": [
    "#creation des donnees:\n",
    "\n",
    "#Images\n",
    "Images = np.load('images_sample.npz')\n",
    "\n",
    "#Changement de base\n",
    "Anchors = np.load('anchors_sample.npz')\n",
    "  \n",
    "#label_D:\n",
    "Labels_D = np.load('labels_D_sample.npz')\n",
    "\n",
    "#label_spoofing:\n",
    "Labels = np.load('label_sample.npz')\n",
    "\n",
    "#set:\n",
    "n=len(Images)\n",
    "\n",
    "data_images = np.zeros((n,256,256,3),dtype=np.float32)\n",
    "data_anchors = np.zeros((n,2,4096),dtype=np.float32)\n",
    "data_labels_D = np.zeros((n,32,32,1),dtype=np.float32)\n",
    "data_labels = np.zeros((n),dtype=np.float32)\n",
    "\n",
    "for item in Images.files:\n",
    "    data_images[int(item),:,:,:] = Images[item]\n",
    "    data_anchors[int(item),:,:] = Anchors[item]\n",
    "    data_labels_D[int(item),:,:,:] = Labels_D[item]\n",
    "    data_labels[int(item)] = Labels[item]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "sj77kbWtSjXF"
   },
   "outputs": [],
   "source": [
    "training_part = 45/55\n",
    "n_train = int(n*training_part)\n",
    "\n",
    "#Training set\n",
    "data_images_train = data_images[:n_train,:,:,:]\n",
    "data_anchors_train = data_anchors[:n_train,:,:]\n",
    "data_labels_D_train = data_labels_D[:n_train,:,:,:]\n",
    "data_labels_train = data_labels[:n_train]\n",
    "\n",
    "#Test set\n",
    "data_images_test = data_images[n_train:,:,:,:]\n",
    "data_anchors_test = data_anchors[n_train:,:,:]\n",
    "data_labels_D_test = data_labels_D[n_train:,:,:,:]\n",
    "data_labels_test = data_labels[n_train:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "K-OKsHuHShkO"
   },
   "outputs": [],
   "source": [
    "def prepare_dataloader_D(data_images_train,data_images_test,data_labels_D_train,data_labels_D_test):\n",
    "  \n",
    "    trainset_D = torch.utils.data.TensorDataset(torch.tensor(np.transpose(data_images_train, (0, 3, 1, 2))), torch.tensor(data_labels_D_train))\n",
    "    testset_D = torch.utils.data.TensorDataset(torch.tensor(np.transpose(data_images_test, (0, 3, 1, 2))), torch.tensor(data_labels_D_test))\n",
    "\n",
    "    trainloader_D = torch.utils.data.DataLoader(trainset_D,batch_size=5,shuffle=False)\n",
    "    testloader_D = torch.utils.data.DataLoader(testset_D,batch_size=5,shuffle=False)\n",
    "\n",
    "    return trainloader_D, testloader_D\n",
    "\n",
    "trainloader_D,testloader_D = prepare_dataloader_D(data_images_train,data_images_test,data_labels_D_train,data_labels_D_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CvFwUq3GSntJ"
   },
   "source": [
    "## Model creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 212
    },
    "colab_type": "code",
    "id": "-n37-Y8ASlXa",
    "outputId": "22749dab-e636-41e3-e8d7-b9d4f4878373"
   },
   "outputs": [],
   "source": [
    "import Anti_Spoof_net\n",
    "\n",
    "mon_model = Anti_Spoof_net.Anti_spoof_net()\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(mon_model.parameters(),lr=3e-3,betas=(0.9, 0.999),eps=1e-08)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qcpVfUyjaEtB"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "DIq5xKw9Spnx"
   },
   "outputs": [],
   "source": [
    "def train_CNN(net, optimizer, trainloader, data_anchors, criterion, n_epoch = 10):\n",
    "    total = 0\n",
    "    for epoch in range(n_epoch):\n",
    "        # loop over the dataset multiple times\n",
    "        running_loss = 0.0\n",
    "        for i, data in enumerate(trainloader, 0):\n",
    "            #Pre-created Data:\n",
    "            images, labels_D = data\n",
    "            # training step\n",
    "            optimizer.zero_grad()\n",
    "            outputs_D, _ = net(images,False,data_anchors[i:i+5,:,:])\n",
    "            #handle NaN:\n",
    "            if (torch.norm((outputs_D != outputs_D).float())==0):   \n",
    "                if (i%50==0 or i%50==1):\n",
    "                    imshow_np(np.transpose(images[0,:,:,:].numpy(),(1,2,0)))\n",
    "                    imshow_np(np.transpose(outputs_D[0,:,:,:].detach().numpy(),(1,2,0)))\n",
    "                loss = criterion(outputs_D, labels_D)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                # compute statistics\n",
    "                total += labels_D.size(0)\n",
    "                running_loss += loss.item()\n",
    "                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / total))\n",
    "        print('Epoch finished')\n",
    "    print('Finished Training')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "u7M6F3l7TOVK"
   },
   "outputs": [],
   "source": [
    "def train_RNN(net, optimizer, trainloader, anchors, labels, criterion, n_epoch = 10):\n",
    "    total = 0\n",
    "    for epoch in range(n_epoch):\n",
    "    # loop over the dataset multiple times\n",
    "    running_loss = 0.0\n",
    "        for i, data in enumerate(trainloader, 0):\n",
    "            #Donnees pre-crees:\n",
    "            images, labels_D = data\n",
    "            # training step\n",
    "            optimizer.zero_grad()\n",
    "            _, outputs_F = net(images,False,anchors[i:i+1,:,:])\n",
    "            #handle NaN:\n",
    "            if (torch.norm((outputs_F != outputs_F).float())==0):   \n",
    "                if (i%50==0 or i%50==1):\n",
    "                    imshow_np(np.transpose(images[0,:,:,:].numpy(),(1,2,0)))\n",
    "                    print('F:')\n",
    "                    print(outputs_F)\n",
    "                if labels[i*5]==0: #toutes les images du batch proviennent de la même vidéo\n",
    "                    label=torch.zeros((5,1,2),dtype=torch.float32)\n",
    "                else:\n",
    "                    label=torch.ones((5,1,2),dtype=torch.float32)\n",
    "                loss = criterion(outputs_F, label)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                # compute statistics\n",
    "                total += labels_D.size(0)\n",
    "                running_loss += loss.item()\n",
    "                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / total))\n",
    "        print('Epoch finished')\n",
    "    print('Finished Training')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### For the overall model training, we alternatively train the CNN part and the CNN/RNN part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1522
    },
    "colab_type": "code",
    "id": "PH8qENwIRjJL",
    "outputId": "7995ee05-d5c6-470d-985a-a9fa9167f6e8"
   },
   "outputs": [],
   "source": [
    "def train_All(net, optimizer, trainloader, anchors, labels, criterion, n_epoch = 10):\n",
    "    for i in range(n_epoch):\n",
    "        train_CNN(net, optimizer, trainloader_D, data_anchors_train, criterion, n_epoch = 1)\n",
    "        torch.save(net,'mon_model')\n",
    "        train_RNN(net, optimizer, trainloader_D, data_anchors_train, data_labels_train, criterion, n_epoch = 1)\n",
    "        torch.save(net,'mon_model')\n",
    "    \n",
    "mon_model = torch.load('mon_model')\n",
    "outputs = train_All(mon_model, optimizer, trainloader_D,data_anchors_train, data_labels_train, criterion, n_epoch = 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "U9TwoiVeSrU1"
   },
   "outputs": [],
   "source": [
    "mon_model = torch.load('mon_model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 102
    },
    "colab_type": "code",
    "id": "GYtBJjrOS_Af",
    "outputId": "fb99df55-be66-4958-929f-9c702be42377"
   },
   "outputs": [],
   "source": [
    "def accuracy(net, criterion, testloader, labels):\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    l=0.015\n",
    "    for i, (images, _ ) in enumerate(testloader,0):\n",
    "        outputs_D, outputs_F = net(images,False,data_anchors_test[i:i+1,:,:])\n",
    "        critere = torch.norm(outputs_D)+l*torch.norm(outputs_F)\n",
    "        #We will take 850 as offset\n",
    "        if (critere>850 and label==1):\n",
    "            correct+=1\n",
    "        if (critere<850 and label==0):\n",
    "            correct==1\n",
    "        total+=1\n",
    "        print(correct/total)\n",
    "    accuracy =  correct / total\n",
    "    loss = loss/total\n",
    "    return accuracy, loss\n",
    "\n",
    "accuracy(mon_model, criterion, testloader_D, data_labels_test)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Project_Spoofing.ipynb",
   "provenance": [],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
